{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Anchor explanations on the Iris dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from alibi.explainers import AnchorTabular"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load iris dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_iris()\n",
    "feature_names = dataset.feature_names\n",
    "class_names = list(dataset.target_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define training and test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 145\n",
    "X_train,Y_train = dataset.data[:idx,:], dataset.target[:idx]\n",
    "X_test, Y_test = dataset.data[idx+1:,:], dataset.target[idx+1:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train Random Forest model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,\n",
       "                       criterion='gini', max_depth=None, max_features='auto',\n",
       "                       max_leaf_nodes=None, max_samples=None,\n",
       "                       min_impurity_decrease=0.0, min_impurity_split=None,\n",
       "                       min_samples_leaf=1, min_samples_split=2,\n",
       "                       min_weight_fraction_leaf=0.0, n_estimators=50,\n",
       "                       n_jobs=None, oob_score=False, random_state=None,\n",
       "                       verbose=0, warm_start=False)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.random.seed(0)\n",
    "clf = RandomForestClassifier(n_estimators=50)\n",
    "clf.fit(X_train, Y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define predict function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "predict_fn = lambda x: clf.predict_proba(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize and fit anchor explainer for tabular data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "explainer = AnchorTabular(predict_fn, feature_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Discretize the ordinal features into quartiles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AnchorTabular(meta={\n",
       "    'name': 'AnchorTabular',\n",
       "    'type': ['blackbox'],\n",
       "    'explanations': ['local'],\n",
       "    'params': {'seed': None, 'disc_perc': (25, 50, 75)}\n",
       "})"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "explainer.fit(X_train, disc_perc=(25, 50, 75))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Getting an anchor\n",
    "\n",
    "Below, we get an anchor for the prediction of the first observation in the test set. An anchor is a sufficient condition - that is, when the anchor holds, the prediction should be the same as the prediction for this instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction:  virginica\n"
     ]
    }
   ],
   "source": [
    "idx = 0\n",
    "print('Prediction: ', class_names[explainer.predictor(X_test[idx].reshape(1, -1))[0]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We set the precision threshold to 0.95. This means that predictions on observations where the anchor holds will be the same as the prediction on the explained instance at least 95% of the time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Anchor: petal width (cm) > 1.80 AND sepal width (cm) <= 2.80\n",
      "Precision: 0.98\n",
      "Coverage: 0.32\n"
     ]
    }
   ],
   "source": [
    "explanation = explainer.explain(X_test[idx], threshold=0.95)\n",
    "print('Anchor: %s' % (' AND '.join(explanation.anchor)))\n",
    "print('Precision: %.2f' % explanation.precision)\n",
    "print('Coverage: %.2f' % explanation.coverage)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
